{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# First-visit MC prediction with the blackjack game\n",
    "\n",
    "Before working through this section, it helps to recap the first-visit Monte Carlo method we\n",
    "learned earlier. Let's now implement first-visit MC prediction for the blackjack game,\n",
    "step by step.\n",
    "\n",
    "Import the necessary libraries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import pandas as pd\n",
    "from collections import defaultdict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
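    "Create a blackjack environment. The `Blackjack-v0` ID and the `reset`/`step` return values used in this notebook follow the older gym API; newer releases register the environment as `Blackjack-v1`, so the code may need adjusting there:"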
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = gym.make('Blackjack-v0')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Defining a policy\n",
    "\n",
    "In the prediction task, we are given an input policy and we predict the value function of\n",
    "that policy. So we first define a policy function that acts as the input policy, that is,\n",
    "the policy whose value function will be predicted in the upcoming steps.\n",
    "\n",
    "Each state is a tuple in which `state[0]` is the sum of our cards, `state[1]` is the dealer's showing card and `state[2]` indicates whether we hold a usable ace. As shown below, our policy function takes the state as input: if `state[0]` is greater than 19, it returns\n",
    "action 0 (stand); otherwise it returns action 1 (hit):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def policy(state):\n",
    "    return 0 if state[0] > 19 else 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a simple, sensible policy rather than a provably optimal one: it makes sense to perform action 0 (stand)\n",
    "when our sum is already greater than 19. With such a high sum, there is no need to perform\n",
    "action 1 (hit) and receive a new card, which is likely to push us over 21, so that we\n",
    "go bust and lose the game.\n",
    "\n",
    "For example, let's generate an initial state by resetting the environment as shown below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(11, 6, False)\n"
     ]
    }
   ],
   "source": [
    "state = env.reset()\n",
    "print(state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, `state[0]` is 11, that is, the sum of our cards is 11, so in this case our\n",
    "policy returns action 1 (hit), as shown below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n"
     ]
    }
   ],
   "source": [
    "print(policy(state))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have defined the policy, in the following sections we will predict the value\n",
    "function (state values) of this policy."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generating an episode\n",
    "Next, we generate an episode using the given policy. To do so, we define a function\n",
    "called `generate_episode`, which takes the policy as input and generates an episode\n",
    "by following it.\n",
    "\n",
    "First, let's set the maximum number of time steps per episode:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_timestep = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_episode(policy):\n",
    "    \n",
    "    #let's define a list called episode for storing the episode\n",
    "    episode = []\n",
    "    \n",
    "    #initialize the state by resetting the environment\n",
    "    state = env.reset()\n",
    "    \n",
    "    #then for each time step\n",
    "    for i in range(num_timestep):\n",
    "        \n",
    "        #select the action according to the given policy\n",
    "        action = policy(state)\n",
    "        \n",
    "        #perform the action and store the next state information\n",
    "        next_state, reward, done, info = env.step(action)\n",
    "        \n",
    "        #store the state, action, reward into our episode list\n",
    "        episode.append((state, action, reward))\n",
    "        \n",
    "        #If the next state is a final state then break the loop else update the next state to the current state\n",
    "        if done:\n",
    "            break\n",
    "            \n",
    "        state = next_state\n",
    "\n",
    "    return episode"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's take a look at what the output of our `generate_episode` function looks like. Note\n",
    "that we generate the episode using the policy we defined earlier:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[((15, 10, False), 1, -1)]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "generate_episode(policy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can observe, the output is a list of `(state, action, reward)` tuples, one per time step. In the run shown above,\n",
    "the episode has a single step: we performed action 1 (hit) in the state `(15, 10, False)`, went bust and\n",
    "received a reward of -1, which ended the episode. Episodes are random, so another run may produce a longer episode or a different outcome.\n",
    "\n",
    "Now that we have learned how to generate an episode using the given policy, next we will\n",
    "look at how to compute the value of a state (the value function) using the first-visit MC\n",
    "method."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Computing the value function\n",
    "\n",
    "We learned that in order to predict the value function, we generate several episodes using\n",
    "the given policy and compute the value of a state as its average return across those\n",
    "episodes. Let's see how to implement that.\n",
    "\n",
    "First, we define `total_return` and `N` as dictionaries for storing, respectively, the total return of each state and the\n",
    "number of times the state is visited across episodes."
   ]
  },
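  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a compact summary, the estimate these two dictionaries support can be sketched as follows. The notation is our own shorthand rather than something defined in the code: $r_k$ is the reward stored at step $k$ of an episode, $T$ is the number of steps in that episode, and $G^{(i)}(s)$ is the return following the first visit to state $s$ in the $i$-th episode that visits it. No discounting is assumed, matching a plain sum of rewards:\n",
    "\n",
    "$$G_t = \\sum_{k=t}^{T-1} r_k, \\qquad V(s) \\approx \\frac{1}{N(s)} \\sum_{i=1}^{N(s)} G^{(i)}(s)$$\n",
    "\n",
    "Here `total_return[s]` holds the sum of the returns and `N[s]` holds the count $N(s)$."
   ]
  },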
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "total_return = defaultdict(float)\n",
    "N = defaultdict(int)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set the number of iterations, that is, the number of episodes we want to generate:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_iterations = 10000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "#then, for every iteration\n",
    "for i in range(num_iterations):\n",
    "    \n",
    "    #generate the episode using the given policy, that is, generate an episode using the policy\n",
    "    #function we defined earlier\n",
    "    episode = generate_episode(policy)\n",
    "    \n",
    "    #store all the states, actions, rewards obtained from the episode\n",
    "    states, actions, rewards = zip(*episode)\n",
    "    \n",
    "    #then, for each step in the episode\n",
    "    for t, state in enumerate(states):\n",
    "        \n",
    "        #if the state is not visited already\n",
    "        if state not in states[0:t]:\n",
    "                \n",
    "            #compute the return R of the state as the sum of the rewards from step t onward (no discounting)\n",
    "            R = sum(rewards[t:])\n",
    "            \n",
    "            #update the total_return of the state\n",
    "            total_return[state] += R\n",
    "            \n",
    "            #update the number of episodes in which the state is visited\n",
    "            N[state] += 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After computing `total_return` and `N`, we can convert them into a pandas data\n",
    "frame for easier inspection. (Note that this is only to make the algorithm easier to follow.\n",
    "We don't have to convert to a pandas data frame; we can\n",
    "implement this just as well using the dictionaries directly.)\n",
    "\n",
    "\n",
    "Convert the `total_return` dictionary to a data frame:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "total_return = pd.DataFrame(total_return.items(),columns=['state', 'total_return'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Convert the counter dictionary `N` to a data frame:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "N = pd.DataFrame(N.items(),columns=['state', 'N'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Merge the two data frames on states:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.merge(total_return, N, on=\"state\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Look at the first few rows of the data frame:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>state</th>\n",
       "      <th>total_return</th>\n",
       "      <th>N</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>(16, 3, False)</td>\n",
       "      <td>-53.0</td>\n",
       "      <td>98</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>(11, 9, False)</td>\n",
       "      <td>6.0</td>\n",
       "      <td>49</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>(21, 9, False)</td>\n",
       "      <td>67.0</td>\n",
       "      <td>68</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>(9, 8, False)</td>\n",
       "      <td>-3.0</td>\n",
       "      <td>27</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>(16, 8, False)</td>\n",
       "      <td>-60.0</td>\n",
       "      <td>95</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>(17, 8, False)</td>\n",
       "      <td>-69.0</td>\n",
       "      <td>117</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>(12, 5, False)</td>\n",
       "      <td>-38.0</td>\n",
       "      <td>91</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>(20, 5, False)</td>\n",
       "      <td>95.0</td>\n",
       "      <td>122</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>(9, 6, False)</td>\n",
       "      <td>2.0</td>\n",
       "      <td>39</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>(14, 6, False)</td>\n",
       "      <td>-53.0</td>\n",
       "      <td>115</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            state  total_return    N\n",
       "0  (16, 3, False)         -53.0   98\n",
       "1  (11, 9, False)           6.0   49\n",
       "2  (21, 9, False)          67.0   68\n",
       "3   (9, 8, False)          -3.0   27\n",
       "4  (16, 8, False)         -60.0   95\n",
       "5  (17, 8, False)         -69.0  117\n",
       "6  (12, 5, False)         -38.0   91\n",
       "7  (20, 5, False)          95.0  122\n",
       "8   (9, 6, False)           2.0   39\n",
       "9  (14, 6, False)         -53.0  115"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head(10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can observe above, we have the total return and\n",
    "the number of times each state is visited.\n",
    "\n",
    "Next, we compute the value of each state as its average return:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "df['value'] = df['total_return']/df['N']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's look at the first few rows of the data frame:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>state</th>\n",
       "      <th>total_return</th>\n",
       "      <th>N</th>\n",
       "      <th>value</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>(16, 3, False)</td>\n",
       "      <td>-53.0</td>\n",
       "      <td>98</td>\n",
       "      <td>-0.540816</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>(11, 9, False)</td>\n",
       "      <td>6.0</td>\n",
       "      <td>49</td>\n",
       "      <td>0.122449</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>(21, 9, False)</td>\n",
       "      <td>67.0</td>\n",
       "      <td>68</td>\n",
       "      <td>0.985294</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>(9, 8, False)</td>\n",
       "      <td>-3.0</td>\n",
       "      <td>27</td>\n",
       "      <td>-0.111111</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>(16, 8, False)</td>\n",
       "      <td>-60.0</td>\n",
       "      <td>95</td>\n",
       "      <td>-0.631579</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>(17, 8, False)</td>\n",
       "      <td>-69.0</td>\n",
       "      <td>117</td>\n",
       "      <td>-0.589744</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>(12, 5, False)</td>\n",
       "      <td>-38.0</td>\n",
       "      <td>91</td>\n",
       "      <td>-0.417582</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>(20, 5, False)</td>\n",
       "      <td>95.0</td>\n",
       "      <td>122</td>\n",
       "      <td>0.778689</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>(9, 6, False)</td>\n",
       "      <td>2.0</td>\n",
       "      <td>39</td>\n",
       "      <td>0.051282</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>(14, 6, False)</td>\n",
       "      <td>-53.0</td>\n",
       "      <td>115</td>\n",
       "      <td>-0.460870</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            state  total_return    N     value\n",
       "0  (16, 3, False)         -53.0   98 -0.540816\n",
       "1  (11, 9, False)           6.0   49  0.122449\n",
       "2  (21, 9, False)          67.0   68  0.985294\n",
       "3   (9, 8, False)          -3.0   27 -0.111111\n",
       "4  (16, 8, False)         -60.0   95 -0.631579\n",
       "5  (17, 8, False)         -69.0  117 -0.589744\n",
       "6  (12, 5, False)         -38.0   91 -0.417582\n",
       "7  (20, 5, False)          95.0  122  0.778689\n",
       "8   (9, 6, False)           2.0   39  0.051282\n",
       "9  (14, 6, False)         -53.0  115 -0.460870"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head(10)"
   ]
  },
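  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, two rows of the table above can be reproduced by hand from their `total_return` and `N` entries (these figures come from the run shown here and will differ on another run):\n",
    "\n",
    "$$V\\big((21, 9, \\text{False})\\big) = \\frac{67}{68} \\approx 0.985294, \\qquad V\\big((16, 3, \\text{False})\\big) = \\frac{-53}{98} \\approx -0.540816$$"
   ]
  },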
  {
   "cell_type": "markdown",
   "metadata": {
    "scrolled": false
   },
   "source": [
    "As we can observe, we now have the value of each state, which is just the average return\n",
    "of the state across several episodes. Thus, we have successfully predicted the value function\n",
    "of the given policy using the first-visit MC method.\n",
    "\n",
    "Okay, let's check the value of some states and see how well our value\n",
    "function is estimated for the given policy. Recall that, to\n",
    "generate episodes, we used the policy that selects action 0 (stand) when the sum\n",
    "is greater than 19 and action 1 (hit) otherwise.\n",
    "\n",
    "Let's evaluate the value of the state `(21,9,False)`. Our sum of cards\n",
    "is already 21, so this is a good state and should have a high value. Let's see what\n",
    "our estimated value of the state is:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.98529412])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df[df['state']==(21,9,False)]['value'].values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can observe above, the value of the state is high.\n",
    "\n",
    "Now, let's check the value of the state `(5,8,False)`. Our sum of cards\n",
    "is just 5 and the dealer's showing card has a high value, 8, so in this case\n",
    "the value of the state should be low. Let's see what our estimated value of the state is:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.55555556])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df[df['state']==(5,8,False)]['value'].values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see, the value of the state is low.\n",
    "\n",
    "Thus, we learned how to predict the value function of the given policy using the first-visit\n",
    "MC prediction method."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
